{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/avikumart/LLM-GenAI-Transformers-Notebooks/blob/main/NL2SQL_With_LlamaIndex.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TfTfxFyiwvd_"
      },
      "source": [
        "## Set up logging"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "asnqVxCzdQfI"
      },
      "outputs": [],
      "source": [
        "import logging\n",
        "import sys\n",
        "\n",
        "from IPython.display import Markdown, display\n",
        "\n",
        "# basicConfig(force=True) replaces any existing root handlers with a single\n",
        "# stdout handler. Do not add a second StreamHandler on top of it: every log\n",
        "# record would then be printed twice.\n",
        "logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XqskaZZfmKyZ",
        "outputId": "252f1d6e-3243-469a-cca5-451420016864"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pymysql in /usr/local/lib/python3.10/dist-packages (1.1.0)\n"
          ]
        }
      ],
      "source": [
        "# %pip installs into the environment of the running kernel\n",
        "%pip install pymysql"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o34KEKFSqu4M"
      },
      "outputs": [],
      "source": [
        "# This notebook is written against the llama_index 0.10 API (llama_index.core\n",
        "# namespace plus ServiceContext, which is deprecated since 0.10), so pin the\n",
        "# 0.10 series instead of installing whatever release is newest.\n",
        "%pip install \"llama-index>=0.10,<0.11\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Djyt4ntlA2kO"
      },
      "outputs": [],
      "source": [
        "# Explicit %pip magic (does not depend on IPython automagic being enabled)\n",
        "%pip install --force-reinstall \"sqlalchemy<2.0.0\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gVol6tE8mVRi"
      },
      "outputs": [],
      "source": [
        "import pymysql"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4XE0pxRww7No"
      },
      "source": [
        "## Connect to Database"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bb_dKSVXhxt3"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Read the connection settings from environment variables so real credentials\n",
        "# are never saved in the notebook. The placeholder values are used only when\n",
        "# the corresponding variable is not set.\n",
        "db_user = os.environ.get(\"DB_USER\", \"user_name\")\n",
        "db_password = os.environ.get(\"DB_PASSWORD\", \"your_password\")\n",
        "db_host = os.environ.get(\"DB_HOST\", \"your_host_name\")\n",
        "db_name = os.environ.get(\"DB_NAME\", \"your_database_name\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Gsv3AI5AlwUc",
        "outputId": "bee2205b-aa33-4bc2-d612-d6d2ee3c6160"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "('Customers',)\n",
            "('OrderDetails',)\n",
            "('Orders',)\n",
            "('info',)\n",
            "('table_36370395d9314572970b325bee42817d',)\n",
            "('table_59af3c5a5388482193f88093fc2eaa36',)\n"
          ]
        }
      ],
      "source": [
        "from urllib.parse import quote_plus\n",
        "\n",
        "from sqlalchemy import create_engine, text\n",
        "\n",
        "# Construct the connection string. The user name and password are\n",
        "# percent-encoded so characters such as \"@\", \":\" or \"/\" in a password\n",
        "# cannot be mistaken for URL delimiters.\n",
        "connection_string = (\n",
        "    f\"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}\"\n",
        "    f\"@{db_host}/{db_name}\"\n",
        ")\n",
        "\n",
        "# Create an engine instance\n",
        "engine = create_engine(connection_string)\n",
        "\n",
        "# Test the connection by listing the tables with raw SQL\n",
        "with engine.connect() as connection:\n",
        "    result = connection.execute(text(\"show tables\"))\n",
        "    for row in result:\n",
        "        print(row)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UeIX0zRWxBVd"
      },
      "source": [
        "## Load the database into a LlamaIndex `SQLDatabase`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WpzC82mtrxT5",
        "outputId": "2fd4a8a7-19b1-4ad7-a2c3-5b10f91b733d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "INFO:numexpr.utils:NumExpr defaulting to 2 threads.\n",
            "NumExpr defaulting to 2 threads.\n"
          ]
        }
      ],
      "source": [
        "from llama_index.core import SQLDatabase"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "O79JFO7Tr3-E",
        "outputId": "864f4f9f-3bfa-4a54-b198-d343b426573a"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<llama_index.core.utilities.sql_wrapper.SQLDatabase at 0x7de656bd2a40>"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Tables handed to the SQLDatabase wrapper through include_tables\n",
        "tables = [\n",
        "    'table_36370395d9314572970b325bee42817d',\n",
        "    'table_59af3c5a5388482193f88093fc2eaa36',\n",
        "]\n",
        "\n",
        "# Wrap the engine; two sample rows per table are requested for the table info\n",
        "sql_database = SQLDatabase(\n",
        "    engine,\n",
        "    include_tables=tables,\n",
        "    sample_rows_in_table_info=2,\n",
        ")\n",
        "sql_database"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AXf7FLDUr_qB",
        "outputId": "1f12f61f-4d1e-400c-955c-bdd33a712083"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'Customers',\n",
              " 'OrderDetails',\n",
              " 'Orders',\n",
              " 'info',\n",
              " 'table_36370395d9314572970b325bee42817d',\n",
              " 'table_59af3c5a5388482193f88093fc2eaa36'}"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "sql_database._all_tables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YDdsRtylsCvB"
      },
      "outputs": [],
      "source": [
        "# Call the method (the bare attribute only displays the bound method itself)\n",
        "# to show the table info for the first selected table.\n",
        "sql_database.get_single_table_info(tables[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uMncMEfVxIBT"
      },
      "source": [
        "## Connect to OpenAI account"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sZZ9L0uQsQpA"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import openai"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zYtPpql9wrw5"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "# The OpenAI client reads the key from the OPENAI_API_KEY environment variable.\n",
        "# Prompt for it only when it is not already set, so the key is never stored in\n",
        "# the notebook source or its outputs.\n",
        "if not os.environ.get(\"OPENAI_API_KEY\"):\n",
        "    os.environ[\"OPENAI_API_KEY\"] = getpass(\"Enter your OpenAI API key: \")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7ywQn3vHw44-"
      },
      "outputs": [],
      "source": [
        "# %pip installs into the environment of the running kernel\n",
        "%pip install llama-index-callbacks-aim"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Cy6D7p6xTtg"
      },
      "source": [
        "## Connect to LLM and initialize the service context"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SIzfNJajwunr"
      },
      "outputs": [],
      "source": [
        "import tiktoken\n",
        "from llama_index.core.callbacks import CallbackManager, TokenCountingHandler\n",
        "\n",
        "# Tokenize with the encoding tiktoken associates with gpt-3.5-turbo\n",
        "gpt35_encoding = tiktoken.encoding_for_model(\"gpt-3.5-turbo\")\n",
        "token_counter = TokenCountingHandler(tokenizer=gpt35_encoding.encode)\n",
        "\n",
        "# Callback manager whose only handler is the token counter\n",
        "callback_manager = CallbackManager([token_counter])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fSyv99l6wzd6",
        "outputId": "ec73509f-4ff6-4d1c-a7a0-1aa7377fc087"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "<ipython-input-17-611c876a3df4>:7: DeprecationWarning: Call to deprecated class method from_defaults. (ServiceContext is deprecated, please use `llama_index.settings.Settings` instead.) -- Deprecated since version 0.10.0.\n",
            "  service_context = ServiceContext.from_defaults(\n"
          ]
        }
      ],
      "source": [
        "from llama_index.core.indices.service_context import ServiceContext\n",
        "from llama_index.embeddings.openai import OpenAIEmbedding\n",
        "from llama_index.llms.openai import OpenAI\n",
        "\n",
        "# GPT-3.5 with a low sampling temperature\n",
        "llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.1)\n",
        "\n",
        "# Bundle the LLM and the callback manager into one service context.\n",
        "# NOTE: ServiceContext is deprecated since llama_index 0.10 (see the warning\n",
        "# in this cell's output).\n",
        "context_options = {\"llm\": llm, \"callback_manager\": callback_manager}\n",
        "service_context = ServiceContext.from_defaults(**context_options)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LNC8smcaxZCl"
      },
      "source": [
        "## Create SQL table node mapping"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vbkj9BEk0Nqt"
      },
      "outputs": [],
      "source": [
        "# Create the SQL table node mapping\n",
        "from llama_index.core import VectorStoreIndex\n",
        "from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema\n",
        "import pandas as pd\n",
        "\n",
        "# List the tables the SQLDatabase exposes through its public API (this honours\n",
        "# include_tables, unlike the private _all_tables attribute) and create one\n",
        "# table schema object per table for the prompt sent to the LLM.\n",
        "tables = list(sql_database.get_usable_table_names())\n",
        "table_node_mapping = SQLTableNodeMapping(sql_database)\n",
        "table_schema_objs = [SQLTableSchema(table_name=table) for table in tables]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RXuln6id55cT",
        "outputId": "eaddc7c9-21e5-41ab-9a7f-b631bd7f10a3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[SQLTableSchema(table_name='table_59af3c5a5388482193f88093fc2eaa36', context_str=None), SQLTableSchema(table_name='Orders', context_str=None), SQLTableSchema(table_name='table_36370395d9314572970b325bee42817d', context_str=None), SQLTableSchema(table_name='OrderDetails', context_str=None), SQLTableSchema(table_name='info', context_str=None), SQLTableSchema(table_name='Customers', context_str=None)]\n"
          ]
        }
      ],
      "source": [
        "print(table_schema_objs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gyVYfOfoxh9q"
      },
      "source": [
        "## Initialize SQL Table Retrieval Engine"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mYTMGRF-4kpY"
      },
      "outputs": [],
      "source": [
        "from llama_index.core.indices.struct_store.sql_query import SQLTableRetrieverQueryEngine\n",
        "\n",
        "# Pass the node mapping and the index class by keyword: binding them by\n",
        "# position depends on the exact parameter order of from_objects.\n",
        "obj_index = ObjectIndex.from_objects(\n",
        "    table_schema_objs,\n",
        "    object_mapping=table_node_mapping,\n",
        "    index_cls=VectorStoreIndex,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-1-6vmM95mcd"
      },
      "outputs": [],
      "source": [
        "# Hand the LLM and the callback manager to the query engine directly instead\n",
        "# of going through the deprecated ServiceContext. The retriever returns the\n",
        "# three table schemas most similar to each question.\n",
        "query_engine = SQLTableRetrieverQueryEngine(\n",
        "    sql_database,\n",
        "    obj_index.as_retriever(similarity_top_k=3),\n",
        "    llm=llm,\n",
        "    callback_manager=callback_manager,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q2jaDosADoMK"
      },
      "outputs": [],
      "source": [
        "response = query_engine.query(\"How many people have previous work experience?\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4cKVpcUkDqqy"
      },
      "outputs": [],
      "source": [
        "print(response)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vHLOaQMoxtv0"
      },
      "source": [
        "## Few-shot examples prompting"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n7NT73COxxTp"
      },
      "outputs": [],
      "source": [
        "# A few example question / SQL pairs used for few-shot prompting\n",
        "examples = [\n",
        "    {\n",
        "        \"input\": \"List all customers in France with a credit limit over 20,000.\",\n",
        "        \"query\": \"SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;\",\n",
        "    },\n",
        "    {\n",
        "        \"input\": \"Get the highest payment amount made by any customer.\",\n",
        "        \"query\": \"SELECT MAX(amount) FROM payments;\",\n",
        "    },\n",
        "    # ... add more examples here\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uzNlsQ7rx64O"
      },
      "outputs": [],
      "source": [
        "# Few-shot example prompt built with a LlamaIndex prompt template\n",
        "from llama_index.core import PromptTemplate\n",
        "from llama_index.llms.openai import OpenAI\n",
        "\n",
        "# Example prompt for SQL query generation\n",
        "qa_prompt_tmpl_str = \"\"\"\\\n",
        "Context information is below.\n",
        "---------------------\n",
        "(\"human\", \"{input}\\nSQLQuery:\"),\n",
        "(\"ai\", \"{query}\"),\n",
        "\"\"\"\n",
        "\n",
        "# Map the keyword arguments passed to format() onto the template variables\n",
        "template_var_mappings = {\"user_queries\": \"input\", \"ai_sql\": \"query\"}\n",
        "\n",
        "# Prompt template\n",
        "prompt_tmpl = PromptTemplate(\n",
        "    qa_prompt_tmpl_str, template_var_mappings=template_var_mappings\n",
        ")\n",
        "\n",
        "# Format the prompt with a user question and its SQL query\n",
        "fmt_prompt = prompt_tmpl.format(\n",
        "    user_queries=\"List all customers in France with a credit limit over 20,000.\",\n",
        "    ai_sql=\"SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;\",\n",
        ")\n",
        "print(fmt_prompt)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dsjj-ZYXyJxj"
      },
      "outputs": [],
      "source": [
        "# Dynamic few shot examples selection\n",
        "from llama_index.core import Document, VectorStoreIndex\n",
        "from llama_index.llms.openai import OpenAI\n",
        "\n",
        "gpt35_llm = OpenAI(model=\"gpt-3.5-turbo\")\n",
        "\n",
        "# from_documents expects Document objects rather than plain dicts, so wrap\n",
        "# each example (question plus its SQL query) in a Document before indexing.\n",
        "example_docs = [\n",
        "    Document(text=f\"Question: {example['input']}\\nSQLQuery: {example['query']}\")\n",
        "    for example in examples\n",
        "]\n",
        "\n",
        "# set up an index of few shot examples\n",
        "index = VectorStoreIndex.from_documents(example_docs)\n",
        "\n",
        "# query string\n",
        "query_str = \"How many employees we have?\"\n",
        "\n",
        "# query engine that retrieves the two most similar examples for a question\n",
        "query_engine = index.as_query_engine(similarity_top_k=2, llm=gpt35_llm)\n",
        "\n",
        "# response built from the retrieved examples, used further for few shot prompting\n",
        "response_prompt = query_engine.query(query_str)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "include_colab_link": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
